import pandas as pd
from sklearn.model_selection import train_test_split
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import nltk
import re
import os
import numpy as np
from nltk import download
from gensim.models import LdaMulticore, CoherenceModel
from gensim.corpora.dictionary import Dictionary
from nltk.tokenize import word_tokenize
import matplotlib.pyplot as plt
import pickle

df = pd.read_csv('/zfs/disinfo/dcweekly_contrast/story_text/dc_weekly_topic_scores.csv')

# Binary 'period' column: 0 for rows whose month is '2023-06', 1 for every other month.
# NOTE(review): this column was described as marking the AI-generated-text period,
# yet the value 1 is assigned to every month EXCEPT '2023-06' — confirm which
# coding downstream analysis expects.
df['period'] = df['month'].map(lambda month: 0 if month == '2023-06' else 1)

# Explicitly set the NLTK data directory
nltk_data_dir = os.path.join('/', 'scratch', 'cehrett', 'nltk_data')
os.environ['NLTK_DATA'] = nltk_data_dir

# Ensure the directory is added to NLTK's search path
nltk.data.path.append(nltk_data_dir)

# Download the necessary packages
nltk.download('stopwords', download_dir=nltk_data_dir)
nltk.download('wordnet', download_dir=nltk_data_dir)

# Download necessary NLTK data
download('stopwords', download_dir=os.path.join('/', 'scratch', 'cehrett', 'nltk_data'))
download('wordnet', download_dir=os.path.join('/', 'scratch', 'cehrett', 'nltk_data'))
download('punkt', download_dir=os.path.join('/', 'scratch', 'cehrett', 'nltk_data'))

# Text-cleaning resources, built once instead of on every preprocess_text call.
# stopwords.words() re-reads the word-list file each time it is called, and a
# single lemmatizer instance can be reused for every document. These lines run
# after the NLTK downloads earlier in this script.
_HTML_TAG_RE = re.compile(r'<.*?>')
_NON_ALPHA_RE = re.compile(r'[^a-zA-Z\s]')
_STOP_WORDS = frozenset(stopwords.words('english'))
_LEMMATIZER = WordNetLemmatizer()


def preprocess_text(text):
    """Normalize a raw document into a space-separated string of lemmas.

    Steps, in order:
      1. Replace HTML tags with a space.
      2. Delete every character that is not an ASCII letter or whitespace
         (punctuation, digits and non-ASCII letters are all removed).
      3. Lowercase.
      4. Drop NLTK English stop words.
      5. Lemmatize each remaining word with WordNet (default part of speech).

    Parameters:
        text (str): Raw document text. Non-string values (e.g. the NaN float
            pandas uses for missing cells) are treated as empty documents.

    Returns:
        str: Lemmas joined by single spaces; '' when no words survive.
    """
    if not isinstance(text, str):
        # Missing cells would otherwise raise TypeError inside re.sub.
        return ''
    # Use a space, not '', so words on either side of a tag
    # (e.g. "<p>one</p><p>two</p>") are not glued into a single token.
    text = _HTML_TAG_RE.sub(' ', text)
    text = _NON_ALPHA_RE.sub('', text).lower()
    lemmas = [_LEMMATIZER.lemmatize(word) for word in text.split() if word not in _STOP_WORDS]
    return ' '.join(lemmas)

# Preprocess the text in the DataFrame
df['processed_text'] = df['clean_text'].apply(preprocess_text)

# Split the DataFrame by the 'period' indicator. .copy() makes each subset an
# independent frame, so adding the 'tokens' column below modifies the subset
# itself rather than a slice of `df` (avoids pandas' SettingWithCopyWarning).
df_period_0 = df[df['period'] == 0].copy()
df_period_1 = df[df['period'] == 1].copy()

# Tokenize the texts
df_period_0['tokens'] = df_period_0['processed_text'].apply(word_tokenize)
df_period_1['tokens'] = df_period_1['processed_text'].apply(word_tokenize)


def _build_dictionary_and_corpus(token_lists):
    """Build a gensim Dictionary over `token_lists` and the matching bag-of-words corpus.

    Parameters:
        token_lists (re-iterable of list of str): Tokenized documents; iterated
            twice (once for the dictionary, once for the corpus).

    Returns:
        tuple: (dictionary, corpus), where corpus[i] is dictionary.doc2bow(token_lists[i]).
    """
    dictionary = Dictionary(token_lists)
    corpus = [dictionary.doc2bow(tokens) for tokens in token_lists]
    return dictionary, corpus


# Each period gets its own vocabulary and corpus.
dictionary_0, corpus_0 = _build_dictionary_and_corpus(df_period_0['tokens'])
dictionary_1, corpus_1 = _build_dictionary_and_corpus(df_period_1['tokens'])

def compute_coherence_values(dictionary, corpus, texts, start, limit, step, *, workers=6):
    """
    Fit one LDA model per topic count and score each with `c_v` topic coherence.

    For every `num_topics` in `range(start, limit, step)`, this fits an
    `LdaMulticore` model on `corpus` and computes its `c_v` coherence against
    `texts`. A progress line is printed before each fit and before each
    coherence computation. All LDA hyperparameters other than the topic count
    and `workers` are fixed:

    - `random_state=100`: fixed seed so repeated runs are reproducible.
    - `chunksize=100`: documents per training chunk (memory vs. update frequency).
    - `passes=10`: full passes over the corpus during training.
    - `alpha='asymmetric'`: gensim's fixed asymmetric document-topic prior, which
      lets some topics be a priori more prevalent than others.
    - `eta=None`: gensim's default symmetric topic-word prior.
    - `decay=0.5`, `offset=64`: online learning-rate parameters; a larger
      `offset` down-weights the earliest updates.
    - `eval_every=10`: estimate log perplexity every 10 updates (adds training time).
    - `iterations=400`: cap on inference iterations for each document.
    - `gamma_threshold=0.001`: convergence threshold for the per-document topic
      parameters (gamma).
    - `minimum_probability=0.01`: topics below this probability are dropped from
      per-document topic output.
    - `minimum_phi_value=0.01`: lower bound on per-word topic values reported
      when per-word topics are requested.
    - `per_word_topics=True`: per-document output also includes the most likely
      topics for each word and their phi values.

    Why `c_v`: it scores a topic's top words with sliding-window co-occurrence
    statistics from `texts`, normalized PMI, and an indirect cosine-similarity
    confirmation step. `u_mass` instead relies on document-level co-occurrence in
    the corpus, while `c_uci` / `c_npmi` apply (normalized) PMI over sliding
    windows directly. In Röder et al. (2015), `c_v` had the highest average
    correlation with human topic ratings among the measures compared.

    Parameters:
        dictionary (gensim.corpora.dictionary.Dictionary): id -> word mapping
            used to build `corpus`.
        corpus (list of list of (int, int)): Bag-of-words documents. Each model
            iterates it for several passes, so it must be re-iterable (e.g. a list).
        texts (iterable of list of str): Tokenized documents for the coherence
            computation; should align with `corpus`. It is materialized into a
            list once, so a one-shot iterator is also accepted.
        start (int): First topic count to fit.
        limit (int): Exclusive upper bound on the topic count (as in `range`).
        step (int): Increment between successive topic counts.
        workers (int, keyword-only): Number of worker processes passed to
            `LdaMulticore`. Defaults to 6.

    Returns:
        tuple: (model_list, coherence_values), where model_list[i] is the fitted
        LdaMulticore model for the i-th topic count in `range(start, limit, step)`
        and coherence_values[i] (float) is its `c_v` coherence.
    """
    # CoherenceModel reads `texts` once per model; materialize it so a generator
    # is not exhausted after the first model.
    texts = list(texts)
    coherence_values = []
    model_list = []
    for num_topics in range(start, limit, step):
        print(f'Beginning LDA with {num_topics} topics')
        model = LdaMulticore(corpus=corpus,
                             id2word=dictionary,
                             num_topics=num_topics,
                             random_state=100,
                             chunksize=100,
                             passes=10,
                             alpha='asymmetric',
                             eta=None,
                             decay=0.5,
                             offset=64,
                             eval_every=10,
                             iterations=400,
                             gamma_threshold=0.001,
                             minimum_probability=0.01,
                             minimum_phi_value=0.01,
                             per_word_topics=True,
                             workers=workers,
                             )
        model_list.append(model)
        print('LDA model fitted. Beginning coherence calculation')
        coherencemodel = CoherenceModel(model=model, texts=texts, dictionary=dictionary, coherence='c_v')
        coherence_values.append(coherencemodel.get_coherence())

    return model_list, coherence_values


# Topic counts evaluated for both periods: 4, 8, ..., 152 (TOPIC_LIMIT is exclusive).
TOPIC_START, TOPIC_LIMIT, TOPIC_STEP = 4, 153, 4

# Fit and score models on period-0 documents (rows where df['period'] == 0)
model_list_0, coherence_values_0 = compute_coherence_values(dictionary=dictionary_0, corpus=corpus_0, texts=df_period_0['tokens'], start=TOPIC_START, limit=TOPIC_LIMIT, step=TOPIC_STEP)

# Fit and score models on period-1 documents (rows where df['period'] == 1)
model_list_1, coherence_values_1 = compute_coherence_values(dictionary=dictionary_1, corpus=corpus_1, texts=df_period_1['tokens'], start=TOPIC_START, limit=TOPIC_LIMIT, step=TOPIC_STEP)

# Topic count for each position in 'models' / 'coherence_values', stored so that
# readers of the pickle do not have to re-derive the grid.
num_topics_grid = list(range(TOPIC_START, TOPIC_LIMIT, TOPIC_STEP))

results = {
    'set_0': {
        'models': model_list_0,
        'coherence_values': coherence_values_0,
        'num_topics': num_topics_grid,
    },
    'set_1': {
        'models': model_list_1,
        'coherence_values': coherence_values_1,
        'num_topics': num_topics_grid,
    },
}

# Save the packaged object (fitted gensim models plus scores) in one pickle file
with open('lda_results.pkl', 'wb') as f:
    pickle.dump(results, f)